Одной из самых интересных задач, с которыми могут справляться нейронные сети, является генерация новых объектов. В этом уроке мы:
Идея GAN довольно проста. Предположим, мы пытаемся обучать сеть (назовем ее сеть 1), которая умеет создавать новые объекты из шума. Когда мы имеем подобную сеть, возникает вопрос: как наша сеть поймет, что она создала хорошие картинки, которые похожи на реальные?
Есть несколько способов решить эту проблему. Например, сравнивать распределения между объектами, генерируемыми сетями, и объектами из реальной природы. Однако можно решить данную проблему сильно проще: давайте кто-то будет сравнивать сгенерированные объекты с настоящими. И этим «кто-то» может быть вторая нейронная сеть (сеть 2), решающая задачу классификации, настоящий перед ней объект или искусственно сгенерированный. Первая сеть обычно называется генератор , а вторая дискриминатор.
Соответственно, наша GAN-модель — это сочетание 2 нейронных сетей, генератора и дискриминатора, которые соревнуются и пытаются обойти друг друга, создавая все более реалистичные изображения.
Рассмотрим пример: представим, что мы искусные мастера, занимающиеся подделкой древностей. Наши работы мы продаем в музеи за большие деньги. В музее есть искусствовед, который может отличать настоящие древности от подделок.
Как же будет происходить наша работа? Вначале, если мы новички, все наши произведения будут отличаться от настоящих древностей, особенно для профессионала. Однако постепенно мы начнем узнавать, что не так с нашими подделками, и делать их все лучше. Искусствовед начнет ошибаться. Но искусствовед в какой-то момент поймет, что приносимые ему подделки все менее и менее отличимы от настоящих древностей. Он тоже начнет учиться отличать хорошие подделки от настоящих объектов. Таким образом мы, мастера, будем учиться подделывать все точнее и точнее, а искусствовед будет все лучше и лучше отличать подделки от реальных древностей.
В мире генерации картинок с помощью GAN все то же самое, только наш нечестный мастер (генератор) и искусствовед (дискриминатор) — нейронные сети.

Мы разобрали, что такое GAN в реальной жизни, давайте теперь научимся строить их в Python. Для этого нам нужно выбрать данные и подключить все библиотеки.
Для работы в данном ноутбуке мы будем пользоваться библиотеками PyTorch и TorchVision в качестве инструмента работы с нейронными сетями.
Импортируем эти и другие необходимые модули.
import random
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import torch
import torch.nn as nn
import torch.optim
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutilsЗафиксируем random seed, чтобы сделать наши эксперименты воспроизводимыми.
manualSeed = 42
random.seed(manualSeed)
torch.manual_seed(manualSeed)
print("Random Seed: ", manualSeed)Random Seed: 42
Для более быстрого обучения нейронных сетей в PyTorch
можно использовать видеокарту, поддерживающую технологию CUDA. Если на
вашем устройстве есть видеокарта, то ячейка ниже
поможет автоматически переключить вычисления на нее.
Если у вас нет видеокарты, не переживайте, вы можете
работать с данным ноутбуком, но вычисления будут происходить
медленнее.
Если вы работаете в Google Colab, не забудьте выбрать среду выполнения GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')В качестве датасета мы выбрали датасет FashionMNIST. Датасет представляет собой черно-белые картинки размером 28 x 28 пикселей с изображением элементов одежды. Всего в датасете 10 классов.
Вам не придется отдельно скачивать данные, так как в
модуле torchvision уже представлен интерфейс работы с этим
датасетом.
Создадим классы Dataset и Dataloader для тренировочной и тестовой части нашего датасета.
Внимание! Если на вашу видеокарту не помещается
модель сделайте параметр batch_size меньше.
# Number of workers for dataloader
workers = 2
batch_size = 128
image_size = 32
transform=transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5), (0.5)),
])
dataset = torchvision.datasets.FashionMNIST(root='FashionMNIST', train=True,
download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
shuffle=True, num_workers=workers)Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to FashionMNIST/FashionMNIST/raw/train-images-idx3-ubyte.gz
{"model_id":"4026512827ee4a1dbeee26902c4840aa","version_major":2,"version_minor":0}Extracting FashionMNIST/FashionMNIST/raw/train-images-idx3-ubyte.gz to FashionMNIST/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to FashionMNIST/FashionMNIST/raw/train-labels-idx1-ubyte.gz
{"model_id":"47a7423e05934b0aadffa4681502db29","version_major":2,"version_minor":0}Extracting FashionMNIST/FashionMNIST/raw/train-labels-idx1-ubyte.gz to FashionMNIST/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to FashionMNIST/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
{"model_id":"8ea8c9bf901e45899bb8a7971c593cf7","version_major":2,"version_minor":0}Extracting FashionMNIST/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to FashionMNIST/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to FashionMNIST/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
{"model_id":"9fa510911e8c49da92908234ca837dec","version_major":2,"version_minor":0}Extracting FashionMNIST/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to FashionMNIST/FashionMNIST/raw
Посмотрим, как выглядит наш датасет.
def grid_visual(batch, n_pictures=64, label=''):
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title(label)
pictures = batch[0].to(device)[:n_pictures]
vis_grid = vutils.make_grid(pictures, padding=2, normalize=True).cpu()
vis_grid = np.transpose(vis_grid,(1,2,0))
plt.imshow(vis_grid)
real_batch = next(iter(dataloader))
grid_visual(real_batch, n_pictures=64, label='Trainign images')
Модель генератора мы будем обозначать через G(z), где z — латентный вектор, из которого происходит генерация. В нашем примере мы будем брать случайный шум в качестве этого вектора. Генератор принимает на вход латентный вектор и создает из него картинку. Соответственно, на выходе у нашего генератора будет трехмерный тензор (многомерная матрица), в котором размерности — длина изображения, ширина изображения и количество каналов (если мы хотим сделать картинку цветной, нам надо сгенерировать 3 канала для красного, зеленого и синего цветов).
Модель дискриминатора мы будем обозначать через D(x). Она принимает на вход картинку (все еще трехмерный тензор) и решает задачу бинарной классификации: определяет, является ли наша картинка настоящей или сгенерированной. Сгенерированные картинки мы будем обозначать через класс 0, настоящие — через класс 1.
Еще раз посмотрим на картинку с моделями, чтобы было понятнее.
Создадим модели генератора и дискриминатора. В качестве моделей мы будем использовать CNN-модели. Генератор и дискриминатор будут представлять из себя симметричные сетки с 4 блоками, состоящими из слоя convolution, слоя batch-normalization и функции активации.
В качестве латентного вектора мы будем использовать вектор случайного шума размерностью 100.
Инициализируем наши модели.
Источник:
https://medium.com/analytics-vidhya/deep-convolutional-generative-adversarial-network-4133bd4779ea
class Generator(nn.Module):
def __init__(self, n_channels=3, latent_size=128, size=64):
super(Generator, self).__init__()
self.seq = nn.Sequential(
nn.ConvTranspose2d( latent_size, size * 4, 4, 1, 0, bias=False),
nn.BatchNorm2d(size * 4),
nn.ReLU(True),
nn.ConvTranspose2d( size * 4, size * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(size * 2),
nn.ReLU(True),
nn.ConvTranspose2d( size * 2, size, 4, 2, 1, bias=False),
nn.BatchNorm2d(size),
nn.ReLU(True),
nn.ConvTranspose2d( size, n_channels, 4, 2, 1, bias=False),
nn.Tanh()
)
def forward(self, x):
return self.seq(x)
class Discriminator(nn.Module):
def __init__(self, size=64, n_channels=3):
super(Discriminator, self).__init__()
self.main = nn.Sequential(
nn.Conv2d(n_channels, size, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(size, size * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(size * 2),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(size * 2, size * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(size * 4),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(size * 4, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)
def forward(self, input):
return self.main(input)#Создаем модель генератора
latent_size = 100
model_gener = Generator(n_channels=1, latent_size=latent_size, size=64).to(device)
model_generGenerator(
(seq): Sequential(
(0): ConvTranspose2d(100, 256, kernel_size=(4, 4), stride=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
(9): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(10): Tanh()
)
)
#Создаем модель дискриминатора
model_disc = Discriminator(n_channels=1, size=64).to(device)
model_discDiscriminator(
(main): Sequential(
(0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
(9): Sigmoid()
)
)
Рассмотрим, как мы будем обучать модель GAN. Для этого представим функции потерь для наших моделей.
Функция потерь для нашего генератора — бинарная кросс-энтропия от сгенерированных картинок:
$$ Loss_G = - \sum_{i=1}^n (y_i \log(\hat{y_i}) + (1 - y_i) \log(1-\hat{y_i})), $$ $$ \hat{y_i} = D(G(\mathbf{z}_i)). $$ Здесь yi — класс объекта для дискриминатора. У сгенерированных фейковых картинок метка 0, у настоящих — 1, zi — исходный латентный вектор для генерации изображения i, в нашем эксперименте мы будем обозначать его как случайный шум.
Однако поскольку наш генератор создает только ненастоящие картинки с классом 0, то мы можем упростить нашу функцию потерь:
$$ Loss_G = -\sum_{i=1}^n \log (1 - D(G(\mathbf{z}_i))). $$
Функция потерь модели дискриминатора — тоже бинарная кросс-энтропия:
$$ Loss_G = - \sum_{i=1}^n (y_i \log(\hat{D(\mathbf{x})}) + (1-y_i) \log(1 - \hat{D(\mathbf{x})})). $$ Здесь yi — истинный класс объекта (у нас 0 — фейковые картинки, 1 — настоящие), x — картинка, подаваемая на вход модели. Давайте немного преобразуем нашу функцию потерь.
Но в случае дискриминатора картинки могут быть как реальные, так и полученные от дискриминатора.
Модели мы будем обучать по очереди: сначала модель дискриминатора, потом генератор. Причина этого очень проста: если наш дискриминатор (искусствовед) ни на что не способен, генератор (создатель поддельных древностей) не поймет, сделал он что-то стоящее или нет. Одной итерацией мы будем называть сначала прогон данных через дискриминатор и оптимизацию параметров, а затем прогон данных через генератор и обучение генератора.
В качестве функции потерь мы будем использовать бинарную кросс-энтропию. В качестве optimizer мы возьмем Adam.
num_epochs = 20
lr = 0.0002
# создаем фиксированный вектор шума, из которого будем генерировать картинки
# чтобы оценить результат визуально
fixed_noise = torch.randn(64, latent_size, 1, 1, device=device)
criterion = nn.BCELoss()
optimizer_disc = torch.optim.Adam(model_disc.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_gener = torch.optim.Adam(model_gener.parameters(), lr=lr, betas=(0.5, 0.999))Запустим процесс обучения:
img_list = [] # сюда будем складывать картинки, чтобы потом посмотреть, как учился наш GAN
gener_losses = [] # сюда - loss Генератора для графика
disc_losses = [] # сюда - loss Дискриминатора
iter_ = 0
n_batches = len(dataloader)
for epoch in range(num_epochs):
for i, (real_images, _) in enumerate(dataloader):
#***************************************
# Обучаем Дискриминатор
#***************************************
model_disc.zero_grad()
real_images = real_images.to(device)
BS = real_images.size(0)
true_labels = torch.ones((BS,), dtype=torch.float, device=device)
# прогоняем реальные картинки через дискриминатор
pred_labels = model_disc(real_images).view(-1)
loss_disc_real = criterion(pred_labels, true_labels)
loss_disc_real.backward()
# прогоняем сгенерированные картинки
# генерим картинки
noise = torch.randn(BS, latent_size, 1, 1, device=device)
fake_images = model_gener(noise)
true_labels = torch.zeros((BS,), dtype=torch.float, device=device)
# прогоняем сгенерированные картинки через дискриминатор
pred_labels = model_disc(fake_images.detach()).view(-1)
loss_disc_fake = criterion(pred_labels, true_labels)
loss_disc_fake.backward()
# обучаем дискриминатор
loss_disc = loss_disc_real + loss_disc_fake
optimizer_disc.step()
#***************************************
# Обучаем Генератор
#***************************************
model_gener.zero_grad()
#Прогоняем сгенерированные картинки через дискриминатор, чтобы обучить генератор
true_labels = torch.ones((BS,), dtype=torch.float, device=device)
pred_labels = model_disc(fake_images).view(-1)
# обучаем Генератор
loss_gener = criterion(pred_labels, true_labels)
loss_gener.backward()
optimizer_gener.step()
# выводим результаты
if i % 50 == 0:
gener_losses.append(loss_gener.item())
disc_losses.append(loss_disc.item())
print(f'ep {epoch}; batch {i}/{n_batches}\t Loss D: {loss_disc.item()}\tLoss G: {loss_gener.item()}')
if (iter_ % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
with torch.no_grad():
fake_images = model_gener(fixed_noise).detach().cpu()
img_list.append(vutils.make_grid(fake_images, padding=2, normalize=True))
iter_ += 1ep 0; batch 0/469 Loss D: 1.4937667846679688 Loss G: 1.2093353271484375
ep 0; batch 50/469 Loss D: 0.013899954967200756 Loss G: 5.690629005432129
ep 0; batch 100/469 Loss D: 0.002782579977065325 Loss G: 6.704823017120361
ep 0; batch 150/469 Loss D: 0.005172377452254295 Loss G: 6.866236686706543
ep 0; batch 200/469 Loss D: 0.28504684567451477 Loss G: 4.555323600769043
ep 0; batch 250/469 Loss D: 0.7482448220252991 Loss G: 2.366978883743286
ep 0; batch 300/469 Loss D: 0.574029803276062 Loss G: 2.3843934535980225
ep 0; batch 350/469 Loss D: 0.5982410907745361 Loss G: 2.832051992416382
ep 0; batch 400/469 Loss D: 0.6306153535842896 Loss G: 3.3191730976104736
ep 0; batch 450/469 Loss D: 0.45480287075042725 Loss G: 1.8240249156951904
ep 1; batch 0/469 Loss D: 0.35172292590141296 Loss G: 2.3746097087860107
ep 1; batch 50/469 Loss D: 0.4249175786972046 Loss G: 2.311898708343506
ep 1; batch 100/469 Loss D: 0.5290405750274658 Loss G: 2.7114367485046387
ep 1; batch 150/469 Loss D: 0.9011371731758118 Loss G: 1.1337352991104126
ep 1; batch 200/469 Loss D: 0.525187611579895 Loss G: 2.3073439598083496
ep 1; batch 250/469 Loss D: 0.586367130279541 Loss G: 1.528287649154663
ep 1; batch 300/469 Loss D: 0.7664244771003723 Loss G: 1.1074213981628418
ep 1; batch 350/469 Loss D: 1.3846405744552612 Loss G: 4.797711372375488
ep 1; batch 400/469 Loss D: 0.608893632888794 Loss G: 1.8229165077209473
ep 1; batch 450/469 Loss D: 0.5552241802215576 Loss G: 2.0942234992980957
ep 2; batch 0/469 Loss D: 0.5245530605316162 Loss G: 1.8416471481323242
ep 2; batch 50/469 Loss D: 0.780532717704773 Loss G: 0.9159003496170044
ep 2; batch 100/469 Loss D: 0.5786683559417725 Loss G: 2.1480560302734375
ep 2; batch 150/469 Loss D: 0.574405312538147 Loss G: 1.7270660400390625
ep 2; batch 200/469 Loss D: 0.9334298372268677 Loss G: 1.0795753002166748
ep 2; batch 250/469 Loss D: 0.777777910232544 Loss G: 1.9880492687225342
ep 2; batch 300/469 Loss D: 0.7166554927825928 Loss G: 2.987734794616699
ep 2; batch 350/469 Loss D: 0.47685593366622925 Loss G: 2.0004026889801025
ep 2; batch 400/469 Loss D: 0.6247270703315735 Loss G: 3.148059368133545
ep 2; batch 450/469 Loss D: 0.6322021484375 Loss G: 1.877185583114624
ep 3; batch 0/469 Loss D: 0.6026877164840698 Loss G: 2.466195583343506
ep 3; batch 50/469 Loss D: 0.4423443675041199 Loss G: 3.358853578567505
ep 3; batch 100/469 Loss D: 0.42981183528900146 Loss G: 2.2605278491973877
ep 3; batch 150/469 Loss D: 0.2224334329366684 Loss G: 2.8951892852783203
ep 3; batch 200/469 Loss D: 1.0038373470306396 Loss G: 0.9826227426528931
ep 3; batch 250/469 Loss D: 0.26325875520706177 Loss G: 2.471280097961426
ep 3; batch 300/469 Loss D: 0.16706979274749756 Loss G: 2.667466402053833
ep 3; batch 350/469 Loss D: 0.20665283501148224 Loss G: 3.2433862686157227
ep 3; batch 400/469 Loss D: 0.934120774269104 Loss G: 1.9394758939743042
ep 3; batch 450/469 Loss D: 0.09475196897983551 Loss G: 3.3879261016845703
ep 4; batch 0/469 Loss D: 0.1090974435210228 Loss G: 3.457209825515747
ep 4; batch 50/469 Loss D: 0.1453503519296646 Loss G: 3.1694111824035645
ep 4; batch 100/469 Loss D: 0.1176614761352539 Loss G: 3.433608055114746
ep 4; batch 150/469 Loss D: 0.060029566287994385 Loss G: 4.032495975494385
ep 4; batch 200/469 Loss D: 0.09354156255722046 Loss G: 4.530811309814453
ep 4; batch 250/469 Loss D: 1.0029122829437256 Loss G: 1.405250072479248
ep 4; batch 300/469 Loss D: 0.6542681455612183 Loss G: 2.3534352779388428
ep 4; batch 350/469 Loss D: 0.27759355306625366 Loss G: 2.3478448390960693
ep 4; batch 400/469 Loss D: 0.30074432492256165 Loss G: 2.214177131652832
ep 4; batch 450/469 Loss D: 0.09932549297809601 Loss G: 3.4997546672821045
ep 5; batch 0/469 Loss D: 0.06577970832586288 Loss G: 3.9780333042144775
ep 5; batch 50/469 Loss D: 0.055192943662405014 Loss G: 3.9820094108581543
ep 5; batch 100/469 Loss D: 0.04721107706427574 Loss G: 4.303548336029053
ep 5; batch 150/469 Loss D: 0.03768550604581833 Loss G: 4.313530921936035
ep 5; batch 200/469 Loss D: 0.05080404132604599 Loss G: 4.052900314331055
ep 5; batch 250/469 Loss D: 0.029349416494369507 Loss G: 4.355489730834961
ep 5; batch 300/469 Loss D: 0.04063452035188675 Loss G: 4.434535026550293
ep 5; batch 350/469 Loss D: 0.027497118338942528 Loss G: 5.046237468719482
ep 5; batch 400/469 Loss D: 0.046557310968637466 Loss G: 4.5303449630737305
ep 5; batch 450/469 Loss D: 0.0424213632941246 Loss G: 4.430820465087891
ep 6; batch 0/469 Loss D: 0.02647353708744049 Loss G: 4.442694664001465
ep 6; batch 50/469 Loss D: 0.02873159945011139 Loss G: 4.5794782638549805
ep 6; batch 100/469 Loss D: 0.7022024393081665 Loss G: 1.6264762878417969
ep 6; batch 150/469 Loss D: 0.7014973163604736 Loss G: 3.3252336978912354
ep 6; batch 200/469 Loss D: 0.43892917037010193 Loss G: 2.860687017440796
ep 6; batch 250/469 Loss D: 0.46613621711730957 Loss G: 2.8642938137054443
ep 6; batch 300/469 Loss D: 0.08037455379962921 Loss G: 3.8944718837738037
ep 6; batch 350/469 Loss D: 0.3663017451763153 Loss G: 2.08122181892395
ep 6; batch 400/469 Loss D: 0.062243007123470306 Loss G: 3.813201904296875
ep 6; batch 450/469 Loss D: 0.047794975340366364 Loss G: 4.129014015197754
ep 7; batch 0/469 Loss D: 0.03491111099720001 Loss G: 4.679188251495361
ep 7; batch 50/469 Loss D: 0.03325328975915909 Loss G: 4.523951530456543
ep 7; batch 100/469 Loss D: 0.03305039554834366 Loss G: 4.503073692321777
ep 7; batch 150/469 Loss D: 0.016984522342681885 Loss G: 5.255037307739258
ep 7; batch 200/469 Loss D: 0.022073067724704742 Loss G: 5.318765163421631
ep 7; batch 250/469 Loss D: 0.028557300567626953 Loss G: 4.717236518859863
ep 7; batch 300/469 Loss D: 0.012888045981526375 Loss G: 5.520537376403809
ep 7; batch 350/469 Loss D: 0.023338302969932556 Loss G: 4.884881019592285
ep 7; batch 400/469 Loss D: 0.01693468913435936 Loss G: 5.203306198120117
ep 7; batch 450/469 Loss D: 0.014700938016176224 Loss G: 6.6988420486450195
ep 8; batch 0/469 Loss D: 0.020526738837361336 Loss G: 5.363239288330078
ep 8; batch 50/469 Loss D: 0.7189997434616089 Loss G: 1.7391115427017212
ep 8; batch 100/469 Loss D: 0.5776994228363037 Loss G: 1.601909875869751
ep 8; batch 150/469 Loss D: 0.480876624584198 Loss G: 2.3942770957946777
ep 8; batch 200/469 Loss D: 0.5589001774787903 Loss G: 3.8310256004333496
ep 8; batch 250/469 Loss D: 0.23903599381446838 Loss G: 2.351200580596924
ep 8; batch 300/469 Loss D: 0.13181641697883606 Loss G: 3.3789544105529785
ep 8; batch 350/469 Loss D: 0.06522852182388306 Loss G: 4.185437202453613
ep 8; batch 400/469 Loss D: 0.025554485619068146 Loss G: 5.011122703552246
ep 8; batch 450/469 Loss D: 0.01590690389275551 Loss G: 5.3681230545043945
ep 9; batch 0/469 Loss D: 0.01193674374371767 Loss G: 6.1686882972717285
ep 9; batch 50/469 Loss D: 0.01629747822880745 Loss G: 5.169459342956543
ep 9; batch 100/469 Loss D: 0.01244533434510231 Loss G: 5.613425254821777
ep 9; batch 150/469 Loss D: 0.006465879734605551 Loss G: 6.081326484680176
ep 9; batch 200/469 Loss D: 0.03582369163632393 Loss G: 5.196817398071289
ep 9; batch 250/469 Loss D: 0.021676043048501015 Loss G: 5.089450836181641
ep 9; batch 300/469 Loss D: 0.007049081847071648 Loss G: 6.588751316070557
ep 9; batch 350/469 Loss D: 0.019252467900514603 Loss G: 5.541248321533203
ep 9; batch 400/469 Loss D: 0.007997090928256512 Loss G: 5.836613655090332
ep 9; batch 450/469 Loss D: 0.00796957965940237 Loss G: 6.02266788482666
ep 10; batch 0/469 Loss D: 0.007169263903051615 Loss G: 6.460685729980469
ep 10; batch 50/469 Loss D: 0.01785324513912201 Loss G: 5.279462814331055
ep 10; batch 100/469 Loss D: 0.009129696525633335 Loss G: 7.159115314483643
ep 10; batch 150/469 Loss D: 0.008204559795558453 Loss G: 5.873725414276123
ep 10; batch 200/469 Loss D: 0.010280273854732513 Loss G: 5.8175249099731445
ep 10; batch 250/469 Loss D: 0.009507905691862106 Loss G: 5.773584842681885
ep 10; batch 300/469 Loss D: 2.8079893589019775 Loss G: 4.613335609436035
ep 10; batch 350/469 Loss D: 0.5120745897293091 Loss G: 1.6789381504058838
ep 10; batch 400/469 Loss D: 0.3465350568294525 Loss G: 2.952090263366699
ep 10; batch 450/469 Loss D: 0.429098516702652 Loss G: 2.9044694900512695
ep 11; batch 0/469 Loss D: 0.33474719524383545 Loss G: 2.123577117919922
ep 11; batch 50/469 Loss D: 0.25049692392349243 Loss G: 2.8739748001098633
ep 11; batch 100/469 Loss D: 0.34395459294319153 Loss G: 2.463099241256714
ep 11; batch 150/469 Loss D: 0.5271680951118469 Loss G: 2.1623029708862305
ep 11; batch 200/469 Loss D: 0.47195255756378174 Loss G: 2.2361440658569336
ep 11; batch 250/469 Loss D: 2.519099473953247 Loss G: 6.959024429321289
ep 11; batch 300/469 Loss D: 0.25807321071624756 Loss G: 3.2097063064575195
ep 11; batch 350/469 Loss D: 0.13557180762290955 Loss G: 3.299124002456665
ep 11; batch 400/469 Loss D: 0.044025346636772156 Loss G: 4.272496223449707
ep 11; batch 450/469 Loss D: 0.024197092279791832 Loss G: 4.943612098693848
ep 12; batch 0/469 Loss D: 0.05700772628188133 Loss G: 4.943131446838379
ep 12; batch 50/469 Loss D: 0.02191094681620598 Loss G: 5.499110221862793
ep 12; batch 100/469 Loss D: 0.02563977614045143 Loss G: 5.091275691986084
ep 12; batch 150/469 Loss D: 0.01687014102935791 Loss G: 5.932982444763184
ep 12; batch 200/469 Loss D: 0.016393788158893585 Loss G: 5.520475387573242
ep 12; batch 250/469 Loss D: 0.7661377787590027 Loss G: 1.546775221824646
ep 12; batch 300/469 Loss D: 0.6288726329803467 Loss G: 1.5122661590576172
ep 12; batch 350/469 Loss D: 0.5581820011138916 Loss G: 3.2820026874542236
ep 12; batch 400/469 Loss D: 1.1664769649505615 Loss G: 0.691653311252594
ep 12; batch 450/469 Loss D: 0.36888569593429565 Loss G: 2.6162986755371094
ep 13; batch 0/469 Loss D: 0.3324792683124542 Loss G: 2.222844123840332
ep 13; batch 50/469 Loss D: 0.5788969397544861 Loss G: 6.298195838928223
ep 13; batch 100/469 Loss D: 0.19459375739097595 Loss G: 2.7048728466033936
ep 13; batch 150/469 Loss D: 0.10582943260669708 Loss G: 3.5022287368774414
ep 13; batch 200/469 Loss D: 0.08053069561719894 Loss G: 3.788170099258423
ep 13; batch 250/469 Loss D: 0.7717230916023254 Loss G: 2.203493595123291
ep 13; batch 300/469 Loss D: 0.2528931796550751 Loss G: 3.63883638381958
ep 13; batch 350/469 Loss D: 0.08339236676692963 Loss G: 3.6089000701904297
ep 13; batch 400/469 Loss D: 0.21092499792575836 Loss G: 2.9837284088134766
ep 13; batch 450/469 Loss D: 0.038445428013801575 Loss G: 4.454034328460693
ep 14; batch 0/469 Loss D: 0.026303458958864212 Loss G: 4.745034217834473
ep 14; batch 50/469 Loss D: 0.022738482803106308 Loss G: 5.251299858093262
ep 14; batch 100/469 Loss D: 0.016661226749420166 Loss G: 5.427088260650635
ep 14; batch 150/469 Loss D: 0.014188871718943119 Loss G: 5.507170677185059
ep 14; batch 200/469 Loss D: 0.01933629997074604 Loss G: 5.378510475158691
ep 14; batch 250/469 Loss D: 0.008621224202215672 Loss G: 5.838019371032715
ep 14; batch 300/469 Loss D: 0.007459650281816721 Loss G: 5.871398448944092
ep 14; batch 350/469 Loss D: 0.0074442243203520775 Loss G: 6.1081223487854
ep 14; batch 400/469 Loss D: 0.010452792048454285 Loss G: 5.606848239898682
ep 14; batch 450/469 Loss D: 0.015714498236775398 Loss G: 5.118441104888916
ep 15; batch 0/469 Loss D: 0.011533009819686413 Loss G: 5.520183086395264
ep 15; batch 50/469 Loss D: 0.010557263158261776 Loss G: 5.789259910583496
ep 15; batch 100/469 Loss D: 0.009205515496432781 Loss G: 6.540092468261719
ep 15; batch 150/469 Loss D: 0.0076754214242100716 Loss G: 6.363622665405273
ep 15; batch 200/469 Loss D: 0.005398456938564777 Loss G: 6.469827651977539
ep 15; batch 250/469 Loss D: 0.005618053488433361 Loss G: 6.814013481140137
ep 15; batch 300/469 Loss D: 0.005054309964179993 Loss G: 6.478793144226074
ep 15; batch 350/469 Loss D: 0.0032767876982688904 Loss G: 7.099100112915039
ep 15; batch 400/469 Loss D: 0.014970965683460236 Loss G: 5.560117721557617
ep 15; batch 450/469 Loss D: 0.010740772821009159 Loss G: 6.052121639251709
ep 16; batch 0/469 Loss D: 0.008473563939332962 Loss G: 6.168562889099121
ep 16; batch 50/469 Loss D: 0.010541001334786415 Loss G: 5.767892837524414
ep 16; batch 100/469 Loss D: 0.007034521549940109 Loss G: 6.1815080642700195
ep 16; batch 150/469 Loss D: 0.015438461676239967 Loss G: 8.964953422546387
ep 16; batch 200/469 Loss D: 0.00652354396879673 Loss G: 6.646697521209717
ep 16; batch 250/469 Loss D: 0.00701667508110404 Loss G: 6.517604827880859
ep 16; batch 300/469 Loss D: 0.003171574557200074 Loss G: 7.092555999755859
ep 16; batch 350/469 Loss D: 0.21980515122413635 Loss G: 3.3964505195617676
ep 16; batch 400/469 Loss D: 0.29211142659187317 Loss G: 3.9267683029174805
ep 16; batch 450/469 Loss D: 0.1661679744720459 Loss G: 3.40801739692688
ep 17; batch 0/469 Loss D: 0.25161027908325195 Loss G: 2.4059486389160156
ep 17; batch 50/469 Loss D: 0.24621984362602234 Loss G: 4.333168983459473
ep 17; batch 100/469 Loss D: 0.16942758858203888 Loss G: 4.081494331359863
ep 17; batch 150/469 Loss D: 0.19936257600784302 Loss G: 3.2313742637634277
ep 17; batch 200/469 Loss D: 0.13399726152420044 Loss G: 4.200002670288086
ep 17; batch 250/469 Loss D: 0.196912944316864 Loss G: 3.487459659576416
ep 17; batch 300/469 Loss D: 0.18254634737968445 Loss G: 2.606510639190674
ep 17; batch 350/469 Loss D: 0.09993831813335419 Loss G: 3.538027048110962
ep 17; batch 400/469 Loss D: 10.110220909118652 Loss G: 0.4652438759803772
ep 17; batch 450/469 Loss D: 0.2025837004184723 Loss G: 2.7093467712402344
ep 18; batch 0/469 Loss D: 0.12469653785228729 Loss G: 4.042414665222168
ep 18; batch 50/469 Loss D: 0.0750548392534256 Loss G: 3.9852514266967773
ep 18; batch 100/469 Loss D: 0.03485891968011856 Loss G: 4.885221481323242
ep 18; batch 150/469 Loss D: 0.025796061381697655 Loss G: 5.229944705963135
ep 18; batch 200/469 Loss D: 0.014415163546800613 Loss G: 6.034151077270508
ep 18; batch 250/469 Loss D: 0.013660097494721413 Loss G: 6.636268138885498
ep 18; batch 300/469 Loss D: 0.013952156528830528 Loss G: 5.676086902618408
ep 18; batch 350/469 Loss D: 0.00634820107370615 Loss G: 6.773355960845947
ep 18; batch 400/469 Loss D: 0.013737009838223457 Loss G: 7.555249214172363
ep 18; batch 450/469 Loss D: 0.010785991325974464 Loss G: 5.79215145111084
ep 19; batch 0/469 Loss D: 0.007185060065239668 Loss G: 7.523918151855469
ep 19; batch 50/469 Loss D: 0.005497686564922333 Loss G: 6.766593933105469
ep 19; batch 100/469 Loss D: 0.004568344913423061 Loss G: 7.1586713790893555
ep 19; batch 150/469 Loss D: 0.4778887629508972 Loss G: 3.3407158851623535
ep 19; batch 200/469 Loss D: 0.23742061853408813 Loss G: 3.47819447517395
ep 19; batch 250/469 Loss D: 0.21565037965774536 Loss G: 2.9750943183898926
ep 19; batch 300/469 Loss D: 0.23682864010334015 Loss G: 2.539046287536621
ep 19; batch 350/469 Loss D: 0.24528293311595917 Loss G: 3.6042966842651367
ep 19; batch 400/469 Loss D: 0.1714928150177002 Loss G: 3.6256589889526367
ep 19; batch 450/469 Loss D: 0.8809894323348999 Loss G: 3.3670623302459717
Выведем наши функции потерь:
plt.figure(figsize=(8, 6))
plt.plot(gener_losses, label='Generator loss')
plt.plot(disc_losses, label='Discriminator loss')
plt.legend()
plt.ylabel('Loss')
plt.xlabel('Itertion')
plt.grid()
plt.show()
К сожалению, по функциям потерь моделей GAN трудно оценить сходимость модели. Обычно понимают, что GAN сошелся, когда обе функции потерь стабилизировались - перестали меняться. Также полезно выводить результаты работы модели во время обучения и сравнивать их с реальными изображениями. Когда качество генерированных картинок перестанет меняться, мы поймем, что модель сошлась.
Отметим, что у нас есть скачки в функциях потерь моделей, особенно это видно по функции потерь дискриминатора. По графику можно заметить, что скачки Discriminator loss происходили, когда падала функция потерь генератора. В эти моменты генератор обучался настолько хорошо, что текущая версия дискриминатор начинала сильно ошибаться. Такие скачки периодически встречаются при обучении GAN. Это нормальный процесс, особенно в начале обучения.
Посмотрим, какие картинки создавались на разных итерациях обучения нашей модели.
def grid_animation(img_list):
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
animate = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
HTML(animate.to_jshtml())fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
animate = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
HTML(animate.to_jshtml())
Сравним чуть ближе настоящие и сгенерированные изображения.
def plot_images(images, label):
len_batch = len(images)
plt.figure(figsize=(20, 3))
plt.title(label)
for i in range(len(images)):
plt.subplot(1, len_batch, i+1)
original_img = images[i] / 2 + 0.5 # unnormalize
matrix_image = original_img.cpu().detach().numpy()
if matrix_image.shape[0] == 1:
image = matrix_image[0]
matrix_image = np.array([image, image, image])
plt.imshow(np.transpose(matrix_image, (1, 2, 0)))
plt.axis('off')n_images = 8
noise = torch.randn(n_images, latent_size, 1, 1, device=device)
images = model_gener(noise)
plot_images(images, label='Сгенерированные изображения')
n_images = 8
dataiter = iter(dataloader)
images, _ = dataiter.next()
images_sample = images[:n_images]
plot_images(images_sample, label='настоящие изображения')
Как мы видим, сгенерированные изображения довольно похожи на настоящие, но все еще отличимы. Чтобы получить более высокое качество, можно было бы поэкспериментировать с архитектурой сетей, lr, количеством эпох, другими параметрами нашего эксперимента и другими генеративными моделями.
Методы генерации изображений развиваются с каждым годом. В настоящее время качество моделей, использующих идеи GAN, заметно выросло по сравнению с моделями 2014–2015 гг.
Более актуальными моделями GAN является Conditional GAN, Projected GAN или CycleGAN. В Сonditional GAN мы добавляем параметр модели, при помощи которого контролируем отдельные аспекты получившегося изображения, например цвет генерируемого автомобиля. В Projected GAN мы в качестве первых слоев дискриминатора используем замороженную предобученную модель, что упрощает задачу дискриминатору. Модель CycleGAN помогает нам переносить стили из одной картинки на другую, как это изображено ниже.
Источник: https://www.kaggle.com/code/netstalker1337/pytorch-cyclegan/notebook
Также сейчас все большую роль в генерации изображений начинают играть диффузионные модели. В их основе лежит идея наложения шума на изображения и обучения модели восстанавливать исходную картинку. Постепенно мы добавляем все более сильный шум, и так в один момент диффузионная модель начинает восстанавливать изображение из чистого шума. Подробнее с этими моделями можно ознакомиться в этой статье.